ESM2 NVFP4 and MXFP8 support and documentation update.#1484

Open
jomitchellnv wants to merge 10 commits into main from jm/mxfp8-nvfp4-mr-02272026

Conversation

@jomitchellnv
Collaborator

@jomitchellnv jomitchellnv commented Feb 27, 2026

Layer-wise MXFP8/NVFP4 precision for ESM-2 TransformerEngine training

Adds support for per-layer quantization precision control, enabling mixed FP8/FP4/BF16
configurations across transformer layers during training. This allows users to assign
different quantization formats to different layers via Hydra config (1-indexed fp8_layers
and fp4_layers lists), enabling convergence/performance tradeoff exploration.

Key changes:

  • Per-layer quantization context in encoder forward: NVEsmEncoder now maintains a
    layer_number_quantized_recipe_map that selects the appropriate TE autocast context per
    layer (nullcontext for FP8 to respect outer autocast, explicit autocast for FP4, or
    autocast(enabled=False) for BF16).
  • quantization.py: New utilities for resolving layer-wise quantization assignments
    (resolve_quantization_layers), generating debug API regex patterns
    (generate_layer_regex), and initializing nvdlfw_inspect quant stats logging
    (initialize_quant_stats_logging). Handles 0-indexed (model internals) and 1-indexed
    (user-facing) layer numbering.
  • train_ddp.py / train_fsdp2.py: Integrated layer-wise quantization setup --
    resolves layer assignments from config, builds recipe map, assigns to encoder, and
    optionally initializes quant stats logging.
  • train_fsdp2_cp.py: Switched from AutoConfig/AutoModelForMaskedLM to local
    NVEsmConfig/NVEsmForMaskedLM for consistency and to avoid remote code trust issues.
  • Hydra config (defaults.yaml): Added fp4_config, quant_stats_config,
    fp8_layers, fp4_layers, and use_fp32_master_weights settings.
  • Model files: Updated esm_nv.py across all checkpoint directories (native_te,
    accelerate, peft) and the models package with layer-wise quantization support, NVTX
    annotations per encoder layer, and FP8_RECIPES/FP4_RECIPES type constants.
  • Tests: Added comprehensive tests for resolve_quantization_layers,
    generate_layer_regex, and update_quant_stats_config covering defaults, explicit
    layers, mixed assignments, overlap validation, and edge cases.
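The per-layer context selection described in the first bullet can be sketched as follows. This is a hypothetical illustration, not the PR's actual code: `Autocast`, `MXFP8BlockScaling`, and `NVFP4BlockScaling` are stand-ins for the TransformerEngine types so the routing logic runs standalone.

```python
from contextlib import nullcontext


class Autocast:
    """Stand-in for transformer_engine.pytorch autocast context."""

    def __init__(self, enabled=True, recipe=None):
        self.enabled = enabled
        self.recipe = recipe

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


class MXFP8BlockScaling:  # stand-in FP8 recipe class
    pass


class NVFP4BlockScaling:  # stand-in FP4 recipe class
    pass


FP8_RECIPES = (MXFP8BlockScaling,)
FP4_RECIPES = (NVFP4BlockScaling,)


def get_layer_autocast(recipe):
    """Return the context manager for one encoder layer given its assigned recipe."""
    if isinstance(recipe, FP8_RECIPES):
        # FP8 layers run under the outer autocast already opened by the
        # training loop, so no new context is opened here.
        return nullcontext()
    if isinstance(recipe, FP4_RECIPES):
        # FP4 layers open an explicit autocast carrying the FP4 recipe.
        return Autocast(enabled=True, recipe=recipe)
    # Layers with no assigned recipe fall back to BF16 by disabling autocast.
    return Autocast(enabled=False)
```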

Usage

TODO: Add code snippet

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Triggering Code Rabbit AI Review

To trigger a code review from CodeRabbit, comment on a pull request with one of these commands:

See https://docs.coderabbit.ai/reference/review-commands for a full list of commands.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Summary by CodeRabbit

Release Notes

  • New Features

    • Layer-wise FP8/FP4 quantization control for ESM2 model training
    • Quantization statistics debugging with tensor-level monitoring for FP4/FP8
    • Tokenizer revision parameter for flexible dataset loading
    • TransformerEngine-optimized model variants
  • Documentation

    • Expanded low-precision training guide covering FP8, MXFP8, and NVFP4
    • Added quantization statistics debugging section with examples
    • Updated performance benchmarks for quantized training scenarios
  • Configuration

    • New FP4 quantization settings block
    • Unified quantization statistics configuration replacing FP8-specific options
    • Support for specifying per-layer quantization precision

@copy-pr-bot

copy-pr-bot bot commented Feb 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 27, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f7906621-2484-432d-8ca6-ff35d28138b5

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Introduces per-layer FP8/FP4 quantization support for TransformerEngine-accelerated ESM2 models via new TE-optimized model classes, layer-wise quantization resolution utilities, per-layer autocast contexts, NVTX instrumentation, and updates to five training strategies (DDP, DDP+CP, FSDP2, FSDP2+CP, mFSDP).

Changes

Cohort / File(s) Summary
TE-Optimized Core Modeling
bionemo-recipes/models/esm2/modeling_esm_te.py, bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py
New TE-enabled ESM model classes (NVEsmConfig, NVEsmEncoder, NVEsmModel, NVEsmForMaskedLM, NVEsmLMHead, NVEsmEmbeddings, NVEsmForTokenClassification) with per-layer quantization, rotary embeddings, and NVTX profiling; initialize_quantization and get_layer_autocast methods manage layer-specific FP8/FP4/BF16 precision via autocast contexts.
Quantization Utilities
bionemo-recipes/recipes/esm2_native_te/quantization.py
New quantization module providing generate_layer_regex, update_quant_stats_config, initialize_quant_stats_logging, QuantizationLayers data holder, and resolve_quantization_layers resolver for layer-wise FP8/FP4 configuration with validation and dynamic stats config generation.
Training Script Integration (DDP/CP)
bionemo-recipes/recipes/esm2_native_te/train_ddp.py, bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
Replaced AutoConfig/AutoModelForMaskedLM with NVEsmConfig/NVEsmForMaskedLM; added per-layer quantization resolution via resolve_quantization_layers; integrated fp8/fp4 recipe creation and encoder.initialize_quantization; updated quant_stats logging based on quant_stats_config.
Training Script Integration (FSDP2)
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py, bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py
Updated config/model initialization to NVEsmConfig.from_pretrained and NVEsmForMaskedLM; added per-layer quantization resolution and initialization on encoder; conditional fp4_recipe creation; adjusted mixed-precision policy and debug hooks to use quant_stats_config.
Training Script Integration (mFSDP)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py
Migrated to NVEsmConfig/NVEsmForMaskedLM; added FP32 master weights guard for mFSDP; integrated quantization layer resolution and per-layer encoder quantization initialization.
Example Encoder Implementations
bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py, bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py, bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
Added FP8_RECIPES and FP4_RECIPES constants; introduced initialize_quantization and get_layer_autocast on NVEsmEncoder; wrapped per-layer forward execution with per-layer autocast contexts and NVTX ranges; added imports for nvtx and nullcontext.
Configuration Files
bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
Renamed fp8_stats_config to quant_stats_config; added new fp4_config block with fp4_recipe, fp4_format, and fp4_recipe_kwargs; added dataset.tokenizer_revision field; introduced fp8_layers and fp4_layers fields; updated use_fp32_master_weights to null.
Quantization Stats Configurations
bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml, bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml
New FP4 stats config with two example collections for layers 1-5 and 6-10; expanded FP8 stats to include LogTensorStats alongside existing LogFp8TensorStats with additional layer types (proj, fc1, fc2).
Dataset & Profiling Updates
bionemo-recipes/recipes/esm2_native_te/dataset.py, bionemo-recipes/recipes/esm2_native_te/perf_logger.py, bionemo-recipes/recipes/esm2_native_te/requirements.txt
Added tokenizer_revision parameter to dataset creation functions; replaced fp8_stats_enabled flag with quant_stats_config in PerfLogger; pinned transformers to version 4.57.3.
Test Coverage
bionemo-recipes/models/esm2/tests/test_layer_quantization.py, bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py, bionemo-recipes/recipes/esm2_native_te/tests/test_train.py, bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py
New test module for NVEsmEncoder quantization and autocast contexts; comprehensive quantization utility tests covering layer regex generation, stats config updates, and layer resolution; updated train tests to use quant_stats_config; removed esm_nv.py from distributed checkpointing expectations.
Documentation & Settings
bionemo-recipes/recipes/esm2_native_te/README.md, .vscode/settings.json
Expanded README with FP8/MXFP8/NVFP4 training documentation, layer-wise precision control, quantization stats debugging, and performance benchmarks; added VSCode logs exclusion and formatting settings.
Docker Configuration
bionemo-recipes/recipes/esm2_native_te/.dockerignore
Updated ignore entries to exclude Dockerfile.\*, quantization profiling artifacts (nsight_profiling, memory_snapshots), logging directories, and development caches while preserving checkpoint/output handling.
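The `generate_layer_regex` utility summarized above could look roughly like this. The pattern shape is an assumption, modeled on the `model\.esm\.encoder\.layers\.(...)` patterns used in the repo's debugging-stats YAML files; the real implementation may differ.

```python
import re


def generate_layer_regex(layer_numbers):
    """Build a regex matching the given encoder layer numbers.

    Hypothetical sketch: matches parameter names like
    model.esm.encoder.layers.<N>.<...>.{layernorm_qkv,proj,fc1,fc2}.
    """
    # Join the layer numbers into an explicit alternation, e.g. "6|7|8|9|10".
    alternation = "|".join(str(n) for n in layer_numbers)
    return (
        r"model\.esm\.encoder\.layers\.(" + alternation + r")\."
        r".*(layernorm_qkv|proj|fc1|fc2)"
    )
```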

Sequence Diagram

sequenceDiagram
    participant User
    participant TrainScript as Train Script<br/>(DDP/FSDP2/mFSDP)
    participant QuantUtil as Quantization<br/>Utils
    participant Modeling as NVEsmConfig &<br/>NVEsmForMaskedLM
    participant Encoder as NVEsmEncoder
    participant TEAutocast as TE Autocast<br/>Context
    participant Forward as Forward<br/>Pass

    User->>TrainScript: Launch training with fp8/fp4 config
    TrainScript->>QuantUtil: resolve_quantization_layers(num_layers, fp8_enabled, fp4_enabled, fp8_layers, fp4_layers)
    QuantUtil-->>TrainScript: QuantizationLayers{fp8_0idx, fp4_0idx, ...}
    
    TrainScript->>Modeling: NVEsmConfig.from_pretrained(model_tag)
    Modeling-->>TrainScript: config
    
    TrainScript->>Modeling: NVEsmForMaskedLM(config)
    Modeling->>Encoder: create NVEsmEncoder instance
    Encoder-->>Modeling: encoder
    Modeling-->>TrainScript: model
    
    TrainScript->>Encoder: initialize_quantization(fp8_layers, fp4_layers, fp8_recipe, fp4_recipe)
    Encoder->>Encoder: build _layer_precision map {layer_idx: 'fp8'|'fp4'|None}
    Encoder->>Encoder: store _fp8_recipe, _fp4_recipe
    Encoder-->>TrainScript: initialized
    
    TrainScript->>Forward: forward(input_ids, attention_mask)
    Forward->>Encoder: for each layer_idx in layers
    Encoder->>Encoder: get_layer_autocast(layer_idx)
    alt Layer is FP8
        Encoder-->>TEAutocast: nullcontext
    else Layer is FP4
        Encoder-->>TEAutocast: autocast(enabled=True, recipe=fp4_recipe)
    else Layer is BF16/None
        Encoder-->>TEAutocast: autocast(enabled=False)
    end
    TEAutocast->>Forward: execute layer within context
    Forward->>Forward: nvtx.range_push("encoder_layer_N")
    Forward->>Forward: hidden_states = layer(hidden_states, ...)
    Forward->>Forward: nvtx.range_pop()
    Forward-->>Encoder: hidden_states
    Encoder-->>Forward: processed output
    Forward-->>TrainScript: logits, loss
    
    TrainScript->>QuantUtil: initialize_quant_stats_logging(if enabled)
    QuantUtil->>QuantUtil: update_quant_stats_config with resolved layers
    QuantUtil-->>TrainScript: debug API initialized
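The resolution step at the top of the diagram converts user-facing 1-indexed layer lists to 0-indexed ones and validates overlap. A minimal sketch, assuming this interface (the real `resolve_quantization_layers` also returns the 1-indexed lists for stats logging, and the real return type is a `QuantizationLayers` object rather than a dict):

```python
def resolve_quantization_layers(num_layers, fp8_layers=None, fp4_layers=None):
    """Validate 1-indexed FP8/FP4 layer lists and return 0-indexed versions."""
    fp8 = sorted(set(fp8_layers or []))
    fp4 = sorted(set(fp4_layers or []))

    # A layer cannot be quantized to both FP8 and FP4 at once.
    overlap = set(fp8) & set(fp4)
    if overlap:
        raise ValueError(f"Layers assigned to both FP8 and FP4: {sorted(overlap)}")

    # User-facing layer numbers are 1-indexed; reject out-of-range entries.
    for n in fp8 + fp4:
        if not 1 <= n <= num_layers:
            raise ValueError(f"Layer {n} outside valid range 1..{num_layers}")

    # Model internals index layers from 0, so shift each entry down by one.
    return {
        "fp8_layers_0indexed": [n - 1 for n in fp8],
        "fp4_layers_0indexed": [n - 1 for n in fp4],
    }
```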

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 Quantization hops through layers with precision keen,
FP8 and FP4 recipes dance in between,
With NVTX ranges marking each stride,
Per-layer contexts guide the forward glide,
Transformer Engine accelerates with care,
ESM2 trains faster in the TE air!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

Check name | Status | Explanation | Resolution
Description check | ❓ Inconclusive | The PR description provides substantial context about layer-wise quantization and key changes to the encoder, training scripts, and configs, but the template sections (Usage, Type of changes, CI configuration, Checklist) are incomplete, with placeholder text and unchecked boxes. | Complete the template sections: provide a usage code snippet, mark the type of changes (likely New feature and Documentation update), and check applicable pre-submit checklist items.
✅ Passed checks (2 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The title 'ESM2 NVFP4 and MXFP8 support and documentation update' clearly summarizes the main changes: adding NVFP4 and MXFP8 quantization support plus documentation updates.
Docstring Coverage | ✅ Passed | Docstring coverage is 86.79%, which meets the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Jonathan Mitchell added 2 commits March 2, 2026 11:50
- includes capability to log out stats for MXFP8
and NVFP4 at the same time
- Enables layer-wise precision setting

Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1429.ipp1a1.colossus.nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/mxfp8-nvfp4-mr-02272026 branch from 091299c to 734a25a on March 2, 2026 19:55
Jonathan Mitchell and others added 2 commits March 2, 2026 15:18
Signed-off-by: Jonathan Mitchell <jomitchell@ipp1-1428.ipp1a1.colossus.nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
@jomitchellnv
Collaborator Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 3, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 11

🧹 Nitpick comments (11)
bionemo-recipes/recipes/esm2_native_te/.dockerignore (1)

33-34: Consider anchoring the scratch path ignore pattern.

j/ on Line 34 matches any directory named j at any depth. If this is intended as a repo-root local scratch dir, prefer /j/ to avoid accidental exclusions in nested paths.

Proposed tweak
- j/
+ /j/
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/.dockerignore` around lines 33 - 34,
The ignore pattern "j/" in .dockerignore matches any directory named "j" at any
depth; change it to "/j/" to anchor it to the repository root so only the
top-level scratch dir is ignored. Update the pattern "j/" to "/j/" (preserve the
trailing slash) in the .dockerignore entry to avoid accidentally excluding
nested directories named "j".
bionemo-recipes/recipes/esm2_native_te/tests/test_train.py (1)

146-158: Config key rename looks correct; consider updating test naming for consistency.

The config keys are correctly updated from fp8_stats_config to quant_stats_config. However, the test function name (test_sanity_ddp_fp8_stats_logging), docstring ("FP8 stats logging"), and variable names (fp8_log_dir) still reference FP8 specifically.

Since quant_stats_config now supports both FP8 and FP4, consider renaming for clarity:

  • test_sanity_ddp_quant_stats_logging
  • quant_log_dir variable

This is optional since the test still validates FP8 stats specifically.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/tests/test_train.py` around lines 146
- 158, Rename test identifiers and docstring to reflect the config rename from
fp8_stats_config to quant_stats_config: update the test function name
test_sanity_ddp_fp8_stats_logging to test_sanity_ddp_quant_stats_logging, change
the fp8_log_dir variable to quant_log_dir (or similar), and update the docstring
"FP8 stats logging" to "quant stats logging" while keeping all uses of
quant_stats_config and the existing assertions intact so the test still
validates FP8 behavior under the new config name.
bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py (1)

60-67: Consider making FP4_RECIPES a tuple for consistency with FP8_RECIPES.

FP8_RECIPES is defined as a tuple of classes, but FP4_RECIPES is a single class. While this works with isinstance(), making it a tuple would be more consistent and future-proof if additional FP4 recipes are added.

♻️ Suggested change for consistency
-FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling
+FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling,)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py` around
lines 60 - 67, FP4_RECIPES is assigned a single class
(transformer_engine.common.recipe.NVFP4BlockScaling) while FP8_RECIPES is a
tuple; make FP4_RECIPES a tuple for consistency and to allow adding more entries
later — update the FP4_RECIPES assignment to use a tuple containing
NVFP4BlockScaling (e.g., FP4_RECIPES =
(transformer_engine.common.recipe.NVFP4BlockScaling,)) so code that expects a
sequence of recipe classes (similar to FP8_RECIPES) will work uniformly.
bionemo-recipes/models/esm2/modeling_esm_te.py (1)

60-67: Same consistency suggestion: consider making FP4_RECIPES a tuple.

This matches the same pattern seen in the checkpoint esm_nv.py files. For consistency across the codebase, consider using a tuple.

♻️ Suggested change
-FP4_RECIPES = transformer_engine.common.recipe.NVFP4BlockScaling
+FP4_RECIPES = (transformer_engine.common.recipe.NVFP4BlockScaling,)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/modeling_esm_te.py` around lines 60 - 67,
FP4_RECIPES is defined as a single value whereas FP8_RECIPES is a tuple; make
FP4_RECIPES a tuple for consistency by wrapping
transformer_engine.common.recipe.NVFP4BlockScaling in a one-element tuple (use a
trailing comma) so the constant mirrors FP8_RECIPES and any tuple-based handling
of recipes (reference FP4_RECIPES and FP8_RECIPES names).
bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml (1)

9-12: Minor formatting inconsistency: missing blank line before WandB section.

Other config files (L1_3B.yaml, L1_15B_perf_test.yaml) have a blank line between the dataset section and the wandb_init_args comment. Consider adding one for consistency.

♻️ Suggested formatting fix
 dataset:
   micro_batch_size: 4
   tokenizer_revision: null
+
 # WandB config
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml` around
lines 9 - 12, Add a blank line between the dataset block and the WandB section
to match other config files: locate the dataset section (keys dataset,
micro_batch_size, tokenizer_revision) and insert a single empty line before the
existing WandB comment/section (the "# WandB config" or the wandb_init_args
section) so formatting is consistent with L1_3B.yaml and L1_15B_perf_test.yaml.
bionemo-recipes/recipes/esm2_native_te/dataset.py (1)

60-62: Simplify redundant conditional for revision parameter.

The expression revision=tokenizer_revision if tokenizer_revision else None is equivalent to just revision=tokenizer_revision since both None and empty string are falsy and the default is None anyway.

♻️ Simplify revision parameter
     tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_name, revision=tokenizer_revision if tokenizer_revision else None
+        tokenizer_name, revision=tokenizer_revision
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/dataset.py` around lines 60 - 62, The
call to AutoTokenizer.from_pretrained uses a redundant conditional for the
revision argument; replace revision=tokenizer_revision if tokenizer_revision
else None with simply revision=tokenizer_revision in the
AutoTokenizer.from_pretrained(...) call so the revision parameter directly uses
the tokenizer_revision variable.
bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml (1)

104-107: Consider adding validation comments for use_fp32_master_weights: null.

Setting use_fp32_master_weights to null rather than a boolean may be intentional (e.g., to require explicit user configuration), but the training script at line 125 uses args.use_fp32_master_weights in a conditional. Verify that null is handled correctly (treated as falsy).

💡 Consider adding a comment explaining the null default
-# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
+# Note: The layers are going to come in 1-indexed and we convert them to 0-indexed at runtime.
 fp8_layers: null
 fp4_layers: null
+# Set explicitly to true/false. When null, defaults to false behavior.
 use_fp32_master_weights: null
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml` around
lines 104 - 107, The default for use_fp32_master_weights is set to null which
may be unexpected when later checked as args.use_fp32_master_weights; update the
defaults.yaml to document that null is intentional and treated as falsy (or
change the default to false) and add a short comment next to
use_fp32_master_weights explaining that the training script expects a
boolean-like value and that null will be treated as false by the conditional
using args.use_fp32_master_weights; ensure any consumers (e.g., the code that
checks args.use_fp32_master_weights) handle null explicitly if you want
different behavior.
bionemo-recipes/recipes/esm2_native_te/quantization.py (2)

61-62: Specify explicit encoding when opening files.

Opening files without explicit encoding can lead to platform-dependent behavior. Specify encoding="utf-8" for consistent behavior.

♻️ Add explicit encoding
-    with open(config_file, "r") as f:
+    with open(config_file, "r", encoding="utf-8") as f:
         config = yaml.safe_load(f)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/quantization.py` around lines 61 - 62,
The file open call using config_file should specify an explicit encoding to
avoid platform-dependent behavior; update the open(...) in the block that reads
the YAML (the with open(config_file, "r") as f: config = yaml.safe_load(f)
statement) to include encoding="utf-8" so the YAML is read consistently across
platforms.

136-157: Consider using @dataclass for QuantizationLayers.

The class is essentially a data container with no methods. Using @dataclass would reduce boilerplate and provide __repr__, __eq__, etc. automatically.

♻️ Convert to dataclass
+from dataclasses import dataclass
+
+@dataclass
 class QuantizationLayers:
     """Resolved layer-wise quantization assignments.

     Attributes:
         fp8_layers_0indexed: 0-indexed FP8 layer numbers (for model internals), or None.
         fp4_layers_0indexed: 0-indexed FP4 layer numbers (for model internals), or None.
         fp8_layers_1indexed: 1-indexed FP8 layer numbers (for user-facing logs / quant stats), or None.
         fp4_layers_1indexed: 1-indexed FP4 layer numbers (for user-facing logs / quant stats), or None.
     """

-    def __init__(
-        self,
-        fp8_layers_0indexed: list[int] | None,
-        fp4_layers_0indexed: list[int] | None,
-        fp8_layers_1indexed: list[int] | None,
-        fp4_layers_1indexed: list[int] | None,
-    ):
-        """Initialize QuantizationLayers with the resolved layer assignments."""
-        self.fp8_layers_0indexed = fp8_layers_0indexed
-        self.fp4_layers_0indexed = fp4_layers_0indexed
-        self.fp8_layers_1indexed = fp8_layers_1indexed
-        self.fp4_layers_1indexed = fp4_layers_1indexed
+    fp8_layers_0indexed: list[int] | None
+    fp4_layers_0indexed: list[int] | None
+    fp8_layers_1indexed: list[int] | None
+    fp4_layers_1indexed: list[int] | None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/quantization.py` around lines 136 -
157, Replace the manual boilerplate class QuantizationLayers with a dataclass:
add "from dataclasses import dataclass" and annotate the class with `@dataclass`,
convert the four constructor args (fp8_layers_0indexed, fp4_layers_0indexed,
fp8_layers_1indexed, fp4_layers_1indexed) to dataclass fields using
Optional[list[int]] types (or list[int] | None) and remove the explicit
__init__; keep the existing class docstring and attribute names so __repr__ and
__eq__ are provided automatically.
bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py (2)

604-609: Hardcoded mask_ratio_train = 0.15 * 0.8 is documented but not configurable.

The comment explains this matches ESM training, which is appropriate. However, consider extracting to a class constant for clarity and potential future configurability.

💡 Extract to class constant
 class NVEsmEmbeddings(nn.Module):
     """Modified version of EsmEmbeddings to support THD inputs."""
+    
+    # Hardcoded mask ratio used in all ESM model training runs (0.15 * 0.8)
+    _MASK_RATIO_TRAIN = 0.12

     def _apply_token_dropout_bshd(self, embeddings, input_ids, attention_mask):
         ...
-        mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
+        mask_ratio_train = self._MASK_RATIO_TRAIN

Also applies to: 625-637

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py` around lines 604 -
609, Replace the hardcoded mask ratio literal by a named class constant and use
it wherever needed: define a class-level constant (e.g., MASK_RATIO_TRAIN = 0.15
* 0.8) on the model class, then update the local variable usage in the method
that currently sets mask_ratio_train and in the similar block at lines 625-637
to read from that constant (self.MASK_RATIO_TRAIN or ClassName.MASK_RATIO_TRAIN)
so scale_factor and embedding scaling use the named constant for clarity and
potential configurability.

228-234: Address TODO: Create unit test for per-layer FP context selection.

The TODO at line 234 notes the need to verify and test this logic. The FP context selection logic (BF16 vs FP8 vs FP4) is critical for correctness.

Would you like me to generate a unit test skeleton for verifying the per-layer FP context selection logic?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py` around lines 228 -
234, Extract the per-layer FP context selection into a small function (e.g.,
get_fp_context(fp_recipe)) that returns nullcontext when fp_recipe is in
FP8_RECIPES, returns transformer_engine.pytorch.autocast(enabled=True,
recipe=fp_recipe) when fp_recipe is in FP4_RECIPES, and returns
transformer_engine.pytorch.autocast(enabled=False) otherwise; then add a unit
test file that parametrizes several fp_recipe values (members of FP8_RECIPES,
FP4_RECIPES, and a default/None case) and asserts that get_fp_context returns
the expected context object type/behavior (e.g., is nullcontext for FP8 cases
and an autocast context for FP4/default) to validate the per-layer FP context
selection logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/models/esm2/modeling_esm_te.py`:
- Around line 214-234: Add unit tests that exercise the per-layer FP context
selection in the loop that iterates self.layers and reads
self.layer_number_quantized_recipe_map: create three scenarios where the mapped
fp_recipe for a given layer is (a) an instance of FP8_RECIPES, (b) an instance
of FP4_RECIPES, and (c) None/other; for each case assert that fp_context becomes
nullcontext() for FP8, transformer_engine.pytorch.autocast(enabled=True,
recipe=fp_recipe) for FP4, and
transformer_engine.pytorch.autocast(enabled=False) for the default/BF16 path.
Use lightweight mocks/monkeypatching to replace
transformer_engine.pytorch.autocast with a stub that records its args/returns so
you can assert enabled and recipe values, and construct minimal model instances
(or unit-test the loop function directly) that set
layer_number_quantized_recipe_map and self.layers to trigger each branch; also
include a test that output_hidden_states True appends hidden_states to
all_hidden_states.

In `@bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py`:
- Line 234: The TODO on the forward path flags unverified precision-routing
behavior; replace it by adding focused unit tests that exercise the precision
routing logic for FP8, FP4 and BF16 (e.g., create tests that run the module's
forward pass with tensors/configs that should route through each precision
branch and assert the correct branch was used and outputs match expected
numerical/shape properties), and then remove the TODO comment. Target the
functions and code paths that implement precision routing (the forward path /
precision-routing conditional blocks referenced by the TODO in esm_nv.py) and
add parametrized tests that cover boundary cases and mixed precision
combinations to ensure deterministic routing.

In `@bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml`:
- Around line 21-23: Update the comment describing layer ranges so it matches
the actual regex stored in layer_name_regex_pattern; the current comment
mentions layers 0-4 / 1-5 but the regex
'model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)'
targets layers 6–10, so change the comment to state layers 6-10 (or 7-11 if
using 1-indexed wording) to prevent confusion and ensure the comment and the
pattern are consistent.
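
A quick standalone check confirms which layers the quoted regex actually targets (the `self_attention.layernorm_qkv` module path below is only an illustrative match candidate):

```python
import re

# The pattern quoted from fp4_debugging_stats.yaml.
pattern = re.compile(
    r"model\.esm\.encoder\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)"
)

# Probe layer indices 0-15 against an illustrative module path.
matched = [
    i
    for i in range(16)
    if pattern.search(f"model.esm.encoder.layers.{i}.self_attention.layernorm_qkv")
]
print(matched)  # [6, 7, 8, 9, 10]
```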

In `@bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml`:
- Around line 19-23: Replace the invalid tensor type "fprop" in the
LogTensorStats block with the valid TransformerEngine tensor type "activation"
(modify the tensors: [dgrad, wgrad, fprop] entry to tensors: [dgrad, wgrad,
activation]); keep stats and freq as-is, or optionally refactor LogTensorStats
to use tensors_struct for per-tensor configs if you need different stats per
tensor type.

In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml`:
- Around line 55-61: Either remove the unused fp4_model_init_kwargs entry from
the fp4_config block in defaults.yaml, or implement FP4 model initialization in
train_fsdp2.py to mirror the FP8 pattern: when args.fp4_config.enabled and
args.fp4_config.fp4_model_init_kwargs.enabled are true, call
transformer_engine.pytorch.quantized_model_init(...) (or the appropriate FP4
init API) with args.fp4_config.fp4_model_init_kwargs before model training;
update any related code paths that reference
fp4_config.fp4_recipe/fp4_format/fp4_recipe_kwargs to ensure consistent
behavior.

In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`:
- Around line 130-133: The assertion message for the padded_vocab_size check is
too long; update the assert in the block that checks self.padded_vocab_size and
self.vocab_size to keep lines <=119 characters by moving the long f-string into
a shorter expression (e.g., build the message in a local variable like msg =
(f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to
" f"vocab_size ({self.vocab_size})") or use implicit string concatenation/split
across lines) and then call assert self.padded_vocab_size >= self.vocab_size,
msg; keep references to self.padded_vocab_size and self.vocab_size so the
semantic check and error content remain unchanged.

In `@bionemo-recipes/recipes/esm2_native_te/quantization.py`:
- Around line 86-94: The temp YAML file created with "temp_file =
tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)" is left on
disk because you return temp_file.name; change this by either (A) making the
function return the config contents or a context-managed path (use
tempfile.NamedTemporaryFile with delete=True or write via
tempfile.TemporaryDirectory and yield the path) so the file is removed
automatically, or (B) accept a caller-provided "quant_log_dir" or "cleanup" flag
and write the file into that deterministic log directory (or delete the temp
file before returning when appropriate). Update the function's docstring and any
callers to reflect whether the caller is responsible for cleanup, and adjust
logging (logger.info uses temp_file.name) to log the deterministic path or note
that the file is ephemeral. Ensure you modify the code around the
NamedTemporaryFile usage and the return value accordingly.

In `@bionemo-recipes/recipes/esm2_native_te/README.md`:
- Line 106: Replace the vague anchor text "here" with a descriptive label that
explains the target, e.g., change "and
[here](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html)"
to something like "and the NVIDIA Transformer Engine FP8 primer" (update the
anchor text around the existing URL), so the README sentence reads clearly and
satisfies markdown linting.
- Around line 80-83: Fix the wording in the low-precision benchmark paragraph:
change "low precision" to "low-precision" throughout, add missing period after
"etc" ("etc."), correct the typo "outweights" to "outweighs", and rephrase the
sentence that reads "the cost to quantize activations from high precision to low
precision outweights the benefits of performing matrix multiplication with low
precision" to a clearer form (e.g., "the cost to quantize activations from
high-precision to low-precision outweighs the benefits of using low-precision
matrix multiplication") so the paragraph reads smoothly and consistently.

In `@bionemo-recipes/recipes/esm2_native_te/train_ddp.py`:
- Around line 62-64: Replace the silent warning when args.fp4_config.enabled is
true with a fail-fast guard: in the block that currently calls logger.warning
(the check of args.fp4_config.enabled), either raise a clear RuntimeError to
stop execution or require an explicit override flag (e.g.,
args.allow_experimental_nvfp4_ddp) before proceeding; update the check to
validate that if fp4 is enabled and DDP is in use (the current DDP launch path),
then if not args.allow_experimental_nvfp4_ddp raise an error with a message
explaining NVFP4+DDP is unsupported and how to opt into experimental mode.
Ensure you reference and change the condition around args.fp4_config.enabled and
the logger.warning call so the run is blocked unless the explicit override is
provided.
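
The fail-fast guard could look like the sketch below; `allow_experimental_nvfp4_ddp` is the reviewer's proposed flag name, not an existing config key.

```python
def require_nvfp4_ddp_opt_in(fp4_enabled: bool, allow_experimental_nvfp4_ddp: bool) -> None:
    """Raise instead of warning when NVFP4 is requested on the DDP launch path."""
    if fp4_enabled and not allow_experimental_nvfp4_ddp:
        raise RuntimeError(
            "NVFP4 with DDP is unsupported. Set allow_experimental_nvfp4_ddp=true "
            "to opt into experimental mode."
        )
```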

In `@bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py`:
- Around line 214-234: Add unit tests validating the per-layer quantization
context selection in the loop that uses layer_number_quantized_recipe_map:
ensure that when layer_number_quantized_recipe_map returns an FP8 recipe the
code sets fp_context to nullcontext(), when it returns an FP4 recipe it uses
transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe), and for
None/other recipes it uses transformer_engine.pytorch.autocast(enabled=False);
add tests covering output_hidden_states path as well (all_hidden_states
behavior) and include edge cases (missing map, unexpected recipe types), and
also correct the typo in the TODO comment from "funciton" to "function".

---

Nitpick comments:
In `@bionemo-recipes/models/esm2/modeling_esm_te.py`:
- Around line 60-67: FP4_RECIPES is defined as a single value whereas
FP8_RECIPES is a tuple; make FP4_RECIPES a tuple for consistency by wrapping
transformer_engine.common.recipe.NVFP4BlockScaling in a one-element tuple (use a
trailing comma) so the constant mirrors FP8_RECIPES and any tuple-based handling
of recipes (reference FP4_RECIPES and FP8_RECIPES names).

In `@bionemo-recipes/recipes/esm2_native_te/.dockerignore`:
- Around line 33-34: The ignore pattern "j/" in .dockerignore matches any
directory named "j" at any depth; change it to "/j/" to anchor it to the
repository root so only the top-level scratch dir is ignored. Update the pattern
"j/" to "/j/" (preserve the trailing slash) in the .dockerignore entry to avoid
accidentally excluding nested directories named "j".

In `@bionemo-recipes/recipes/esm2_native_te/dataset.py`:
- Around line 60-62: The call to AutoTokenizer.from_pretrained uses a redundant
conditional for the revision argument; replace revision=tokenizer_revision if
tokenizer_revision else None with simply revision=tokenizer_revision in the
AutoTokenizer.from_pretrained(...) call so the revision parameter directly uses
the tokenizer_revision variable.

In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml`:
- Around line 104-107: The default for use_fp32_master_weights is set to null
which may be unexpected when later checked as args.use_fp32_master_weights;
update the defaults.yaml to document that null is intentional and treated as
falsy (or change the default to false) and add a short comment next to
use_fp32_master_weights explaining that the training script expects a
boolean-like value and that null will be treated as false by the conditional
using args.use_fp32_master_weights; ensure any consumers (e.g., the code that
checks args.use_fp32_master_weights) handle null explicitly if you want
different behavior.

In `@bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml`:
- Around line 9-12: Add a blank line between the dataset block and the WandB
section to match other config files: locate the dataset section (keys dataset,
micro_batch_size, tokenizer_revision) and insert a single empty line before the
existing WandB comment/section (the "# WandB config" or the wandb_init_args
section) so formatting is consistent with L1_3B.yaml and L1_15B_perf_test.yaml.

In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`:
- Around line 604-609: Replace the hardcoded mask ratio literal by a named class
constant and use it wherever needed: define a class-level constant (e.g.,
MASK_RATIO_TRAIN = 0.15 * 0.8) on the model class, then update the local
variable usage in the method that currently sets mask_ratio_train and in the
similar block at lines 625-637 to read from that constant (self.MASK_RATIO_TRAIN
or ClassName.MASK_RATIO_TRAIN) so scale_factor and embedding scaling use the
named constant for clarity and potential configurability.
- Around line 228-234: Extract the per-layer FP context selection into a small
function (e.g., get_fp_context(fp_recipe)) that returns nullcontext when
fp_recipe is in FP8_RECIPES, returns
transformer_engine.pytorch.autocast(enabled=True, recipe=fp_recipe) when
fp_recipe is in FP4_RECIPES, and returns
transformer_engine.pytorch.autocast(enabled=False) otherwise; then add a unit
test file that parametrizes several fp_recipe values (members of FP8_RECIPES,
FP4_RECIPES, and a default/None case) and asserts that get_fp_context returns
the expected context object type/behavior (e.g., is nullcontext for FP8 cases
and an autocast context for FP4/default) to validate the per-layer FP context
selection logic.

In `@bionemo-recipes/recipes/esm2_native_te/quantization.py`:
- Around line 61-62: The file open call using config_file should specify an
explicit encoding to avoid platform-dependent behavior; update the open(...) in
the block that reads the YAML (the with open(config_file, "r") as f: config =
yaml.safe_load(f) statement) to include encoding="utf-8" so the YAML is read
consistently across platforms.
- Around line 136-157: Replace the manual boilerplate class QuantizationLayers
with a dataclass: add "from dataclasses import dataclass" and annotate the class
with `@dataclass`, convert the four constructor args (fp8_layers_0indexed,
fp4_layers_0indexed, fp8_layers_1indexed, fp4_layers_1indexed) to dataclass
fields using Optional[list[int]] types (or list[int] | None) and remove the
explicit __init__; keep the existing class docstring and attribute names so
__repr__ and __eq__ are provided automatically.
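
The suggested dataclass form could look like this sketch (field names mirror the existing class; the `None` defaults are an added convenience, not part of the original signature):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantizationLayers:
    """Resolved layer-wise quantization assignments, in dataclass form."""

    fp8_layers_0indexed: Optional[list[int]] = None
    fp4_layers_0indexed: Optional[list[int]] = None
    fp8_layers_1indexed: Optional[list[int]] = None
    fp4_layers_1indexed: Optional[list[int]] = None
```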

In `@bionemo-recipes/recipes/esm2_native_te/tests/test_train.py`:
- Around line 146-158: Rename test identifiers and docstring to reflect the
config rename from fp8_stats_config to quant_stats_config: update the test
function name test_sanity_ddp_fp8_stats_logging to
test_sanity_ddp_quant_stats_logging, change the fp8_log_dir variable to
quant_log_dir (or similar), and update the docstring "FP8 stats logging" to
"quant stats logging" while keeping all uses of quant_stats_config and the
existing assertions intact so the test still validates FP8 behavior under the
new config name.

In `@bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py`:
- Around line 60-67: FP4_RECIPES is assigned a single class
(transformer_engine.common.recipe.NVFP4BlockScaling) while FP8_RECIPES is a
tuple; make FP4_RECIPES a tuple for consistency and to allow adding more entries
later — update the FP4_RECIPES assignment to use a tuple containing
NVFP4BlockScaling (e.g., FP4_RECIPES =
(transformer_engine.common.recipe.NVFP4BlockScaling,)) so code that expects a
sequence of recipe classes (similar to FP8_RECIPES) will work uniformly.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd72d88 and e45116a.

⛔ Files ignored due to path filters (5)
  • docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-mxfp8-10node-conv.svg is excluded by !**/*.svg
  • docs/docs/assets/images/esm2/esm2_low_precision/esm2-15b-b300-nvfp4-10node-conv.svg is excluded by !**/*.svg
  • docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-mxfp8-6node-conv.png is excluded by !**/*.png
  • docs/docs/assets/images/esm2/esm2_low_precision/esm2-3b-b200-nvfp4-6node-conv.png is excluded by !**/*.png
  • docs/docs/assets/images/esm2/esm2_low_precision/esm2_8gpu_tflops.png is excluded by !**/*.png
📒 Files selected for processing (24)
  • .vscode/settings.json
  • bionemo-recipes/models/esm2/modeling_esm_te.py
  • bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/.dockerignore
  • bionemo-recipes/recipes/esm2_native_te/README.md
  • bionemo-recipes/recipes/esm2_native_te/dataset.py
  • bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/fp4_debugging_stats.yaml
  • bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_15B_perf_test.yaml
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_3B.yaml
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/L1_650M.yaml
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
  • bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py
  • bionemo-recipes/recipes/esm2_native_te/perf_logger.py
  • bionemo-recipes/recipes/esm2_native_te/quantization.py
  • bionemo-recipes/recipes/esm2_native_te/requirements.txt
  • bionemo-recipes/recipes/esm2_native_te/tests/test_quantization.py
  • bionemo-recipes/recipes/esm2_native_te/tests/test_train.py
  • bionemo-recipes/recipes/esm2_native_te/train_ddp.py
  • bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py
  • bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py

@jomitchellnv jomitchellnv changed the title NVFP4 and MXFP8 integrations ESM2 NVFP4 and MXFP8 support and documentation update. Mar 3, 2026
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
- removes nsys
- adds unit tests for layer-wise quantization

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
@jomitchellnv
Collaborator Author

@coderabbitai resolve

@coderabbitai
Contributor

coderabbitai bot commented Mar 4, 2026

✅ Actions performed

Comments resolved. Auto-approval is disabled; enable reviews.request_changes_workflow to approve automatically.

@jomitchellnv
Collaborator Author

@coderabbitai resolve

@coderabbitai
Contributor

coderabbitai bot commented Mar 4, 2026

✅ Actions performed

Comments resolved. Auto-approval is disabled; enable reviews.request_changes_workflow to approve automatically.

@jomitchellnv
Collaborator Author

/ok to test 8c53527

@jomitchellnv
Collaborator Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 4, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py (1)

69-77: Gate FP8 recipe construction behind fp8_config.enabled.

fp8_recipe is instantiated even when FP8 is disabled. This can make BF16/FP4-only runs fail on irrelevant FP8 recipe config errors.

♻️ Suggested change
-    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-        fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-    )
+    fp8_recipe = None
+    if args.fp8_config.enabled:
+        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py` around lines 69 - 77,
fp8_recipe is always instantiated even when FP8 is disabled; change the logic to
only construct fp8_recipe when args.fp8_config.enabled is true (mirror how
fp4_recipe is handled). Specifically, initialize fp8_recipe = None and wrap the
hydra.utils.get_class(...) construction in an if args.fp8_config.enabled: block
so that fp8_recipe is only created when enabled, referencing the existing
symbols fp8_recipe and args.fp8_config.enabled to locate and modify the code.
bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py (1)

87-90: Gate FP8 recipe creation on fp8_config.enabled here as well.

This avoids requiring FP8 recipe validity for runs that only use BF16/FP4.

♻️ Suggested change
-    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
-        fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-    )
+    fp8_recipe = None
+    if args.fp8_config.enabled:
+        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+            fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py` around lines 87 -
90, The FP8 recipe is created unconditionally which forces FP8 recipe validation
even when FP8 is disabled; wrap the hydra.utils.get_class(...) call that
constructs fp8_recipe in a conditional that checks args.fp8_config.enabled and
only builds fp8_recipe (using Format[args.fp8_config.fp8_format] and
args.fp8_config.fp8_recipe_kwargs) when enabled, otherwise set fp8_recipe to
None or skip creation so BF16/FP4-only runs don't require FP8 recipe validity.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/models/esm2/tests/test_layer_quantization.py`:
- Around line 16-242: Add a golden-value parity test that runs the
TransformerEngine ESM encoder (NVEsmEncoder) and the reference ESM encoder on
the same deterministic input/seed and asserts numerical parity (e.g., final
token logits or pooled embeddings) within a small tolerance; create a new test
function (e.g., test_te_vs_reference_golden_value_parity) in this module that
uses torch.manual_seed, a small random input tensor on CUDA, constructs an
NVEsmEncoder via NVEsmConfig and constructs the reference ESM model (import the
reference model used in the repo), runs both forward passes with identical
settings, and asserts outputs are close with pytest.approx or torch.allclose;
ensure the test uses the existing encoder fixture pattern/device and keeps the
comparison deterministic and tolerant to tiny numeric differences.

In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py`:
- Around line 213-222: In initialize_quantization, validate fp8_layers and
fp4_layers before applying them: convert to sets (fp8_layers_set,
fp4_layers_set), check every layer id is an int within 0..len(self.layers)-1 and
raise a ValueError if any id is out of range, check for overlap by computing
intersection = fp8_layers_set & fp4_layers_set and raise a ValueError if
non-empty, and optionally ensure inputs are unique/convertible to int; only
after these checks populate self._layer_precision using range(len(self.layers)).

In `@bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py`:
- Around line 78-81: The FP8 recipe is being constructed unconditionally
(fp8_recipe via hydra.utils.get_class and Format[args.fp8_config.fp8_format])
even when FP8 is disabled; wrap that construction in an if
args.fp8_config.enabled: guard (same pattern used for fp4_recipe) so fp8_recipe
is only created when args.fp8_config.enabled is true, and ensure any references
to fp8_recipe are only used within that guarded block or handled when disabled.

---

Nitpick comments:
In `@bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py`:
- Around line 87-90: The FP8 recipe is created unconditionally which forces FP8
recipe validation even when FP8 is disabled; wrap the hydra.utils.get_class(...)
call that constructs fp8_recipe in a conditional that checks
args.fp8_config.enabled and only builds fp8_recipe (using
Format[args.fp8_config.fp8_format] and args.fp8_config.fp8_recipe_kwargs) when
enabled, otherwise set fp8_recipe to None or skip creation so BF16/FP4-only runs
don't require FP8 recipe validity.

In `@bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py`:
- Around line 69-77: fp8_recipe is always instantiated even when FP8 is
disabled; change the logic to only construct fp8_recipe when
args.fp8_config.enabled is true (mirror how fp4_recipe is handled).
Specifically, initialize fp8_recipe = None and wrap the
hydra.utils.get_class(...) construction in an if args.fp8_config.enabled: block
so that fp8_recipe is only created when enabled, referencing the existing
symbols fp8_recipe and args.fp8_config.enabled to locate and modify the code.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 18a61b77-3157-42be-b9b8-05feda4cad9e

📥 Commits

Reviewing files that changed from the base of the PR and between e45116a and 8c53527.

📒 Files selected for processing (14)
  • bionemo-recipes/models/esm2/modeling_esm_te.py
  • bionemo-recipes/models/esm2/tests/test_layer_quantization.py
  • bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/fp8_debugging_stats.yaml
  • bionemo-recipes/recipes/esm2_native_te/hydra_config/defaults.yaml
  • bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py
  • bionemo-recipes/recipes/esm2_native_te/tests/test_distributed_checkpointing.py
  • bionemo-recipes/recipes/esm2_native_te/train_ddp.py
  • bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py
  • bionemo-recipes/recipes/esm2_native_te/train_mfsdp.py
  • bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py

Comment on lines +16 to +242
"""Unit tests for NVEsmEncoder.initialize_quantization and get_layer_autocast."""

from contextlib import nullcontext
from unittest.mock import patch

import pytest
import transformer_engine.common.recipe
import transformer_engine.pytorch

from modeling_esm_te import NVEsmConfig, NVEsmEncoder


@pytest.fixture
def encoder():
"""Create a small NVEsmEncoder on CUDA for testing."""
config = NVEsmConfig(
hidden_size=320,
intermediate_size=1280,
num_hidden_layers=6,
num_attention_heads=20,
max_position_embeddings=1026,
)
return NVEsmEncoder(config)


class TestInitializeQuantization:
"""Tests for NVEsmEncoder.initialize_quantization."""

def test_all_fp8(self, encoder):
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
encoder.initialize_quantization(
fp8_layers=[0, 1, 2, 3, 4, 5],
fp4_layers=None,
fp8_recipe=fp8_recipe,
fp4_recipe=None,
)
assert encoder._fp8_recipe is fp8_recipe
assert encoder._fp4_recipe is None
assert all(encoder._layer_precision[i] == "fp8" for i in range(6))

def test_all_fp4(self, encoder):
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=None,
fp4_layers=[0, 1, 2, 3, 4, 5],
fp8_recipe=None,
fp4_recipe=fp4_recipe,
)
assert encoder._fp8_recipe is None
assert encoder._fp4_recipe is fp4_recipe
assert all(encoder._layer_precision[i] == "fp4" for i in range(6))

def test_all_bf16(self, encoder):
encoder.initialize_quantization(
fp8_layers=None,
fp4_layers=None,
fp8_recipe=None,
fp4_recipe=None,
)
assert all(encoder._layer_precision[i] is None for i in range(6))

def test_mixed_fp8_fp4(self, encoder):
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=[0, 1, 2],
fp4_layers=[3, 4, 5],
fp8_recipe=fp8_recipe,
fp4_recipe=fp4_recipe,
)
for i in range(3):
assert encoder._layer_precision[i] == "fp8"
for i in range(3, 6):
assert encoder._layer_precision[i] == "fp4"

def test_mixed_fp8_bf16(self, encoder):
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
encoder.initialize_quantization(
fp8_layers=[0, 2, 4],
fp4_layers=None,
fp8_recipe=fp8_recipe,
fp4_recipe=None,
)
assert encoder._layer_precision[0] == "fp8"
assert encoder._layer_precision[1] is None
assert encoder._layer_precision[2] == "fp8"
assert encoder._layer_precision[3] is None
assert encoder._layer_precision[4] == "fp8"
assert encoder._layer_precision[5] is None

def test_mixed_all_three(self, encoder):
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=[0, 1],
fp4_layers=[4, 5],
fp8_recipe=fp8_recipe,
fp4_recipe=fp4_recipe,
)
assert encoder._layer_precision[0] == "fp8"
assert encoder._layer_precision[1] == "fp8"
assert encoder._layer_precision[2] is None # BF16
assert encoder._layer_precision[3] is None # BF16
assert encoder._layer_precision[4] == "fp4"
assert encoder._layer_precision[5] == "fp4"

def test_empty_lists_treated_as_none(self, encoder):
encoder.initialize_quantization(
fp8_layers=[],
fp4_layers=[],
fp8_recipe=None,
fp4_recipe=None,
)
assert all(encoder._layer_precision[i] is None for i in range(6))

def test_covers_all_layers(self, encoder):
encoder.initialize_quantization(
fp8_layers=[0],
fp4_layers=None,
fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
fp4_recipe=None,
)
assert len(encoder._layer_precision) == 6

def test_recipes_stored_as_attributes(self, encoder):
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=[0],
fp4_layers=[1],
fp8_recipe=fp8_recipe,
fp4_recipe=fp4_recipe,
)
# Recipes are stored once, not duplicated per-layer in the map.
assert encoder._fp8_recipe is fp8_recipe
assert encoder._fp4_recipe is fp4_recipe
# The map only contains strings, not recipe objects.
for v in encoder._layer_precision.values():
assert v is None or isinstance(v, str)


class TestGetLayerAutocast:
"""Tests for NVEsmEncoder.get_layer_autocast."""

def test_fp8_layer_returns_nullcontext(self, encoder):
encoder.initialize_quantization(
fp8_layers=[0],
fp4_layers=None,
fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
fp4_recipe=None,
)
ctx = encoder.get_layer_autocast(0)
assert isinstance(ctx, nullcontext)

def test_fp4_layer_returns_te_autocast(self, encoder):
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=None,
fp4_layers=[0],
fp8_recipe=None,
fp4_recipe=fp4_recipe,
)
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
mock_autocast.return_value = "fp4_context"
ctx = encoder.get_layer_autocast(0)
mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe)
assert ctx == "fp4_context"

def test_bf16_layer_returns_te_autocast_disabled(self, encoder):
encoder.initialize_quantization(
fp8_layers=None,
fp4_layers=None,
fp8_recipe=None,
fp4_recipe=None,
)
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
mock_autocast.return_value = "bf16_context"
ctx = encoder.get_layer_autocast(0)
mock_autocast.assert_called_once_with(enabled=False)
assert ctx == "bf16_context"

def test_uninitialized_defaults_to_bf16(self, encoder):
"""When initialize_quantization was never called, all layers default to BF16."""
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
mock_autocast.return_value = "bf16_context"
ctx = encoder.get_layer_autocast(0)
mock_autocast.assert_called_once_with(enabled=False)
assert ctx == "bf16_context"

def test_mixed_layers_return_correct_contexts(self, encoder):
fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=[0, 1],
fp4_layers=[2, 3],
fp8_recipe=fp8_recipe,
fp4_recipe=fp4_recipe,
)
# FP8 layers -> nullcontext
assert isinstance(encoder.get_layer_autocast(0), nullcontext)
assert isinstance(encoder.get_layer_autocast(1), nullcontext)

# FP4 and BF16 layers -> te.pytorch.autocast (not nullcontext)
with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
mock_autocast.return_value = "fp4_context"
encoder.get_layer_autocast(2)
mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe)

with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
mock_autocast.return_value = "bf16_context"
encoder.get_layer_autocast(4)
mock_autocast.assert_called_with(enabled=False)

def test_layer_precision_map_is_pickleable(self, encoder):
"""The _layer_precision map should be trivially pickleable (only strings/None)."""
import pickle

fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
encoder.initialize_quantization(
fp8_layers=[0, 1],
fp4_layers=[2, 3],
fp8_recipe=fp8_recipe,
fp4_recipe=fp4_recipe,
)
roundtripped = pickle.loads(pickle.dumps(encoder._layer_precision))
assert roundtripped == encoder._layer_precision
🛠️ Refactor suggestion | 🟠 Major

Add at least one TE-vs-reference golden-value parity test in this module.

These tests cover routing/context behavior, but they do not assert numerical parity between the TE model and the reference ESM model for a fixed input/seed.

As per coding guidelines: "In bionemo-recipes/models/, create golden value tests proving that the TransformerEngine model matches the reference model".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/tests/test_layer_quantization.py` around lines 16
- 242, Add a golden-value parity test that runs the TransformerEngine ESM
encoder (NVEsmEncoder) and the reference ESM encoder on the same deterministic
input/seed and asserts numerical parity (e.g., final token logits or pooled
embeddings) within a small tolerance; create a new test function (e.g.,
test_te_vs_reference_golden_value_parity) in this module that uses
torch.manual_seed, a small random input tensor on CUDA, constructs an
NVEsmEncoder via NVEsmConfig and constructs the reference ESM model (import the
reference model used in the repo), runs both forward passes with identical
settings, and asserts outputs are close with pytest.approx or torch.allclose;
ensure the test uses the existing encoder fixture pattern/device and keeps the
comparison deterministic and tolerant to tiny numeric differences.

Comment on lines +213 to +222
        fp8_layers_set = set(fp8_layers) if fp8_layers else set()
        fp4_layers_set = set(fp4_layers) if fp4_layers else set()
        self._layer_precision = {}
        for layer_number in range(len(self.layers)):
            if layer_number in fp8_layers_set:
                self._layer_precision[layer_number] = "fp8"
            elif layer_number in fp4_layers_set:
                self._layer_precision[layer_number] = "fp4"
            else:
                self._layer_precision[layer_number] = None
⚠️ Potential issue | 🟠 Major

Validate overlaps and bounds in initialize_quantization.

Right now, overlapping or out-of-range layer IDs are silently accepted. This can hide config mistakes and route layers to unintended precision.

♻️ Suggested fix
         self._fp8_recipe = fp8_recipe
         self._fp4_recipe = fp4_recipe
         fp8_layers_set = set(fp8_layers) if fp8_layers else set()
         fp4_layers_set = set(fp4_layers) if fp4_layers else set()
+        overlap = fp8_layers_set & fp4_layers_set
+        if overlap:
+            raise ValueError(f"fp8_layers and fp4_layers overlap: {sorted(overlap)}")
+
+        valid_layers = set(range(len(self.layers)))
+        invalid = (fp8_layers_set | fp4_layers_set) - valid_layers
+        if invalid:
+            raise ValueError(
+                f"Layer indices out of range [0, {len(self.layers) - 1}]: {sorted(invalid)}"
+            )
+
         self._layer_precision = {}
         for layer_number in range(len(self.layers)):
             if layer_number in fp8_layers_set:
                 self._layer_precision[layer_number] = "fp8"
             elif layer_number in fp4_layers_set:
                 self._layer_precision[layer_number] = "fp4"
             else:
                 self._layer_precision[layer_number] = None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/modeling_esm_te.py` around lines 213-222, in initialize_quantization, validate fp8_layers and fp4_layers before
applying them: convert to sets (fp8_layers_set, fp4_layers_set), check every
layer id is an int within 0..len(self.layers)-1 and raise a ValueError if any id
is out of range, check for overlap by computing intersection = fp8_layers_set &
fp4_layers_set and raise a ValueError if non-empty, and optionally ensure inputs
are unique/convertible to int; only after these checks populate
self._layer_precision using range(len(self.layers)).
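The validation the comment asks for can be sketched as a standalone helper; `build_layer_precision_map` is a hypothetical name, not a function in the repo, and it operates on the 0-indexed layer numbers used by the model internals:

```python
def build_layer_precision_map(num_layers, fp8_layers=None, fp4_layers=None):
    """Build {layer_index: "fp8" | "fp4" | None}, rejecting bad configs.

    Sketch of the suggested checks: overlapping assignments and
    out-of-range indices raise instead of being silently accepted.
    """
    fp8_set = set(fp8_layers or [])
    fp4_set = set(fp4_layers or [])

    overlap = fp8_set & fp4_set
    if overlap:
        raise ValueError(f"fp8_layers and fp4_layers overlap: {sorted(overlap)}")

    invalid = (fp8_set | fp4_set) - set(range(num_layers))
    if invalid:
        raise ValueError(
            f"Layer indices out of range [0, {num_layers - 1}]: {sorted(invalid)}"
        )

    return {
        n: "fp8" if n in fp8_set else "fp4" if n in fp4_set else None
        for n in range(num_layers)
    }
```

For example, `build_layer_precision_map(4, fp8_layers=[0, 2], fp4_layers=[1])` yields `{0: "fp8", 1: "fp4", 2: "fp8", 3: None}`, while passing layer 0 in both lists raises.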

Comment on lines +78 to 81
# Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether FP8 recipe construction is gated by fp8_config.enabled in train_ddp_cp.py.
cat -n bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py | sed -n '74,90p'
echo "---"
rg -n "fp8_recipe|fp8_config.enabled" -C3 bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py

Repository: NVIDIA/bionemo-framework

Length of output: 2742


Guard FP8 recipe construction behind fp8_config.enabled.

Lines 79–81 construct fp8_recipe unconditionally, even when FP8 is disabled. This contradicts the comment on line 78 and creates an asymmetry with fp4_recipe (lines 82–86), which is correctly guarded by if args.fp4_config.enabled:. Disabled FP8 runs will still fail if the FP8 config is invalid.

Suggested fix
+    fp8_recipe = None
+    if args.fp8_config.enabled:
-    fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
+        fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
             fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
-    )
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py` around lines 78 - 81,
The FP8 recipe is being constructed unconditionally (fp8_recipe via
hydra.utils.get_class and Format[args.fp8_config.fp8_format]) even when FP8 is
disabled; wrap that construction in an if args.fp8_config.enabled: guard (same
pattern used for fp4_recipe) so fp8_recipe is only created when
args.fp8_config.enabled is true, and ensure any references to fp8_recipe are
only used within that guarded block or handled when disabled.

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
@jomitchellnv
Copy link
Collaborator Author

I also re-ran the code on a B300 node
[screenshot of the B300 training run]

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>